from typing import Any, TypeVar
from ..modules import Module


# Type-stub declaration of the WeightNorm helper class. Every body is `...`,
# so only the signatures are described here, not runtime behavior.
class WeightNorm:
    # Stored name; presumably the name of the module parameter being
    # normalized — confirm against the runtime implementation.
    name: str = ...
    # Stored dimension, as passed to __init__ / apply.
    dim: int = ...

    def __init__(self, name: str, dim: int) -> None: ...

    # TODO Make return type more specific
    # NOTE(review): presumably returns the recomputed weight tensor — confirm
    # against the runtime implementation before narrowing from Any.
    def compute_weight(self, module: Module) -> Any: ...

    # Static factory: takes the target module plus name/dim and returns a
    # WeightNorm instance.
    @staticmethod
    def apply(module: Module, name: str, dim: int) -> 'WeightNorm': ...

    # Takes the module and returns None. Presumably undoes whatever `apply`
    # set up on the module — verify against the implementation.
    def remove(self, module: Module) -> None: ...

    # Called with (module, inputs); returns None. NOTE(review): the signature
    # looks like a forward pre-hook — confirm how it is registered.
    def __call__(self, module: Module, inputs: Any) -> None: ...


# Type variable bounded by Module. weight_norm / remove_weight_norm use it so
# their return type is the same (sub)class as the module argument passed in.
T_module = TypeVar('T_module', bound=Module)


# Accepts any Module (sub)class instance and is typed to return that same
# (sub)class via T_module. The defaults for `name` and `dim` are elided (`...`)
# in this stub; check the runtime implementation for their actual values.
def weight_norm(module: T_module, name: str = ..., dim: int = ...) -> T_module: ...


# Counterpart to weight_norm. It accepts any Module (sub)class instance and is
# typed to return that same (sub)class via T_module. The default for `name` is
# elided (`...`) in this stub; check the runtime implementation for its value.
def remove_weight_norm(module: T_module, name: str = ...) -> T_module: ...
